import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt 
import glob
import numpy as np
from matplotlib.ticker import ScalarFormatter
import string
class ScalarFormatterForceFormat(ScalarFormatter):
    """ScalarFormatter subclass that pins the tick-label format to "%0.2f"."""

    def _set_format(self):  
        # Override the ScalarFormatter hook and always assign a fixed
        # two-decimal format string to ``self.format``.
        self.format = "%0.2f" 
      
# =============================================================================
# Processing the DSSAT output to generate a dataframe for plots
# =============================================================================
# Folder scanned for the whitespace-delimited DSSAT output files (*.txt).
path=r"\DSSAT outputs"
all_files = glob.glob(path + "/*.txt")
if not all_files:
    # Fail early with a clear message instead of the opaque
    # "No objects to concatenate" error raised by pd.concat([]).
    raise FileNotFoundError("No DSSAT output files (*.txt) found in " + repr(path))

my_cols = ['TRT', 'PDAT', 'ADAT', 'MDAT', 'Yield', 'SWC', 'DRCM', 'IR', 'Ng',
           'RU', 'TMAX', 'TMIN', 'PRCP', 'ET', 'Tr']

# The files carry no header row; -99 is read as a missing value.
Da = [
    pd.read_csv(F, delimiter=r'\s+', names=my_cols, header=None, na_values=[-99])
    for F in all_files
]

DSSAT = pd.concat(Da, axis=0, ignore_index=True)
DSSAT.dropna(inplace=True)

# The last character of the treatment code selects the offset added to the
# planting year: codes 1 and 2 add 0, every other code adds 100.
DSSAT['YYcode'] = DSSAT['TRT'].str[-1:].astype(int)
DSSAT['YYcode_interpret'] = np.where(DSSAT['YYcode'].isin([1, 2]), 0, 100)
DSSAT['YYYY'] = DSSAT['PDAT'].astype(str).str[:4].astype(int)
DSSAT['year'] = DSSAT['YYYY'] + DSSAT['YYcode_interpret']

# Fixed-position fields sliced out of the treatment code.
DSSAT['GIS_code'] = DSSAT['TRT'].str[4:-5]
DSSAT['CNCT'] = DSSAT['TRT'].str[:4]
DSSAT['GCM'] = DSSAT['TRT'].str[11:-1]
DSSAT['RCP'] = DSSAT['TRT'].str[8:-2]
DSSAT['Country'] = DSSAT['TRT'].str[:2]
DSSAT['City'] = DSSAT['CNCT'].str.strip().replace(
    {
        "USCa": "California",
        "ItEm": "Emilia",
        "ItFo": "Foggia",
        "ChIn": "Inner Mongolia",
        "ChGa": "Gansu",
        "ChXi": "Xinjiang",
    },
    regex=True,
)

gis_data = pd.read_csv(r'\GIS-Jan2021.csv')
DSSAT_lat_lon = DSSAT.merge(gis_data, how='left', on='TRT')

# Latitude threshold used to split California into a northern and a southern
# part. NOTE(review): an earlier note said northern California starts at
# latitude 35.8, which disagrees with this value - confirm which is intended.
NC_SC = 37
west_of_114 = DSSAT_lat_lon['LON'] < -114
north_of_split = DSSAT_lat_lon['LAT'] > NC_SC
south_of_split = DSSAT_lat_lon['LAT'] <= NC_SC
# Use a single .loc assignment: the former chained form
# (frame['City'][mask] = ...) writes to a temporary object and has no effect
# when pandas Copy-on-Write is active.
DSSAT_lat_lon.loc[north_of_split & west_of_114, 'City'] = "Northern California"
DSSAT_lat_lon.loc[south_of_split & west_of_114, 'City'] = "Southern California"

col_list = ['year', 'City', 'GCM', 'RCP', 'Yield']
allcities = ['California', 'Emilia', 'Foggia', 'Xinjiang', 'Gansu', 'Inner Mongolia']

# =============================================================================
# =============================================================================
# # Density plots ////  ptitplots ///// ridge plots
# =============================================================================
# =============================================================================

import matplotlib
import ptitprince as pt

# plt.figure(dpi= 380)
col_list_ridge = ['year', 'Country', 'GCM', 'RCP', 'Yield']
# Take an explicit copy so the column assignments below modify this frame
# only and do not trigger SettingWithCopyWarning on a slice of DSSAT.
ridgeplot_data = DSSAT[col_list_ridge].copy()
ridgeplot_data['Country'] = ridgeplot_data['Country'].str.strip().replace(
    {"It": "Italy", "US": "United States", "Ch": "China"}, regex=True)
# Divide yield by 1000 (plotted below as "t DM/ ha").
ridgeplot_data["Yield"] = ridgeplot_data["Yield"] / 1000

# Assign each row to a 30-year window; rows from 2070 onwards (and rows whose
# year is missing) keep the default label.
ridgeplot_data['Period'] = "2070-2099"
ridge_year = ridgeplot_data["year"]
# .loc assignments replace the former chained indexing
# (frame['Period'][mask] = ...), which has no effect under Copy-on-Write.
ridgeplot_data.loc[ridge_year < 2010, 'Period'] = "1980-2009"
ridgeplot_data.loc[(ridge_year >= 2010) & (ridge_year < 2040), 'Period'] = "2010-2039"
ridgeplot_data.loc[(ridge_year >= 2040) & (ridge_year < 2070), 'Period'] = "2040-2069"
ridgeplot_data['RCP'] = ridgeplot_data['RCP'].str.strip().replace(
    {"R26": "RCP 2.6", "R70": "RCP 7.0", "R85": "RCP 8.5"}, regex=True)


# =============================================================================
# Generate the petit plot (histo+boxplots) for all GCMs
# =============================================================================

countries = ["United States","Italy","China"]

plt.rcParams['xtick.labelsize'] = 28
plt.rcParams['ytick.labelsize'] = 28
plt.rcParams["font.weight"] = 800   # tick labels
plt.rcParams["axes.labelweight"] = 800  # axis labels
paltte_pt=['#7F3C8D','#11A579','#3969AC','#F2B701','#E73F74']

fig1, axes = plt.subplots(2, 2, figsize=[28.5,16.5], sharex=False, sharey=False, dpi=300)
fig1.subplots_adjust(hspace=0.1, wspace=0.10)
axs = axes.ravel()

# Keyword arguments shared by every raincloud panel of this figure.
raincloud_style = dict(x="RCP", y="Yield", palette=paltte_pt, hue="GCM", bw=0.2,
                       label='_nolegend_', width_viol=0.75, move=0, alpha=.4,
                       trim=False, offset=.1, orient='v', pointplot=True)

# Panels A-C: one panel per country, yearly mean yield per RCP and GCM.
for r, cnts in enumerate(countries):
    ridgeplot_data_i = ridgeplot_data.loc[ridgeplot_data['Country'] == cnts]
    # Average the Yield column only: asking for the mean of the remaining
    # string columns (Country, Period) raises TypeError on pandas >= 2.0.
    ridgeplot_data_i = (ridgeplot_data_i.groupby(["year", "RCP", "GCM"])["Yield"]
                        .mean().reset_index())

    axs[r] = pt.RainCloud(data=ridgeplot_data_i, ax=axs[r], **raincloud_style)
    axs[r].set_xlabel("")
    axs[r].get_legend().remove()
    axs[r].set_ylabel('Yield (t DM/ ha)', labelpad=15, fontsize=37)
    axs[r].yaxis.set_major_formatter(ScalarFormatterForceFormat())
    axs[r].annotate("(" + string.ascii_uppercase[r] + ") " + cnts, xy=(0.02, 0.03),
                    xycoords='axes fraction', fontsize=38)
    plt.tight_layout()

# Panel D: the same aggregation over all countries together.
ridgeplot_data_i = ridgeplot_data.groupby(["year", "RCP", "GCM"])["Yield"].mean().reset_index()
axs[3] = pt.RainCloud(data=ridgeplot_data_i, ax=axs[3], **raincloud_style)
axs[3].set_xlabel("")
axs[3].set_ylabel('Yield (t DM/ ha)', labelpad=15, fontsize=32)
axs[3].yaxis.set_major_formatter(ScalarFormatterForceFormat())
axs[3].annotate("(D) Global (mean)", xy=(0.02, 0.03), xycoords='axes fraction', fontsize=38)
plt.tight_layout()

# Build the shared legend once, after every panel exists (it used to be
# rebuilt on each loop pass, the first ones with no handles available yet).
# Handles come from the third panel and are relabelled generically.
h, l = axs[2].get_legend_handles_labels()
l = ["GCM 1", "GCM 2", "GCM 3", "GCM 4", "GCM 5"]
col_lgd = plt.legend(h[:5], l[:], loc='upper left',
                     bbox_to_anchor=(-1.20, -0.10), fancybox=True, shadow=True,
                     ncol=5, fontsize=40)

# =============================================================================
# Generate the petit plot (histo+boxplots) for the mean of GCMs
# =============================================================================

import string

paltte_pt=['#FFB14E', '#EA5F94','#0b9681']

fig1, axes = plt.subplots(2, 2, figsize=[28.5,16.5], sharex=False, sharey=False, dpi=300)
fig1.subplots_adjust(hspace=0.1, wspace=0.10)
axs = axes.ravel()

# Keyword arguments shared by every raincloud panel of this figure.
raincloud_style = dict(x="RCP", y="Yield", palette=paltte_pt, bw=0.2,
                       width_viol=0.75, move=0, trim=False, offset=.1,
                       orient='v', pointplot=True)

# Panels A-C: one panel per country, yearly mean yield per RCP (all GCMs pooled).
for r, cnts in enumerate(countries):
    ridgeplot_data_i = ridgeplot_data.loc[ridgeplot_data['Country'] == cnts]
    # Average the Yield column only: asking for the mean of the remaining
    # string columns (Country, GCM, Period) raises TypeError on pandas >= 2.0.
    ridgeplot_data_i = ridgeplot_data_i.groupby(["year", "RCP"])["Yield"].mean().reset_index()

    axs[r] = pt.RainCloud(data=ridgeplot_data_i, ax=axs[r], **raincloud_style)
    axs[r].set_xlabel("")
    axs[r].set_ylabel('Yield (t DM/ ha)', labelpad=15, fontsize=37)
    axs[r].yaxis.set_major_formatter(ScalarFormatterForceFormat())
    axs[r].annotate("(" + string.ascii_uppercase[r] + ") " + cnts, xy=(0.02, 0.03),
                    xycoords='axes fraction', fontsize=38)
    plt.tight_layout()

# Panel D: the same aggregation over all countries together.
ridgeplot_data_i = ridgeplot_data.groupby(["year", "RCP"])["Yield"].mean().reset_index()
axs[3] = pt.RainCloud(data=ridgeplot_data_i, ax=axs[3], **raincloud_style)
axs[3].set_xlabel("")
axs[3].set_ylabel('Yield (t DM/ ha)', labelpad=15, fontsize=32)
axs[3].yaxis.set_major_formatter(ScalarFormatterForceFormat())
axs[3].annotate("(D) Global (mean)", xy=(0.02, 0.03), xycoords='axes fraction', fontsize=38)
plt.tight_layout()

# =============================================================================
# Heat Maps and box plots 
# =============================================================================


# Folder scanned for the per-location climate CSV files.
path=r"D:\Davide\extracted_comma delimited"
all_files = glob.glob(path + "/*.csv")
li = [pd.read_csv(F) for F in all_files]

df = pd.concat(li, axis=0, ignore_index=True)
# NOTE(review): this is a substring (regex) replacement, so a value that
# already reads "Inner Mongolia" would become "Inner Mongolia Mongolia".
df['City'] = df['City'].str.strip().replace({"Inner": "Inner Mongolia"}, regex=True)
df['TMEAN'] = (df['TMAX'] + df['TMIN']) / 2

# Assign each row to a 30-year window; rows from 2070 onwards (and rows whose
# year is missing) keep the default label.
df['Period'] = "2070-2099"
climate_year = df["year"]
# .loc assignments replace the former chained indexing
# (df['Period'][mask] = ...), which has no effect under Copy-on-Write.
df.loc[climate_year < 2010, 'Period'] = "1980-2009"
df.loc[(climate_year >= 2010) & (climate_year < 2040), 'Period'] = "2010-2039"
df.loc[(climate_year >= 2040) & (climate_year < 2070), 'Period'] = "2040-2069"

# Scenario rows only: drop the baseline records and the year 2100.
df_1 = df[~df['RCP'].str.contains("Baseline")]
df_1 = df_1[df_1['year'] != 2100]
df_1 = df_1.reset_index(drop=True)

# Replace the model and scenario names with short codes.
df_1['GCM'] = df_1['GCM'].str.strip().replace(
    {
        "GFDL-ESM4": "1",
        "IPSL-CM6A-LR": "2",
        "MPI-ESM1-2-HR": "3",
        "MRI-ESM2-0": "4",
        "UKESM1-0-LL": "5",
    },
    regex=True,
)
df_1['RCP'] = df_1['RCP'].str.strip().replace(
    {"SSP1-RCP2.6": "R26", "SSP3-RCP7.0": "R70", "SSP5-RCP8.5": "R85"}, regex=True)

col_list_raw = ['year', 'Period', 'Country', 'City', 'GCM', 'RCP', 'TMEAN', 'TMAX',
                'Lat', 'TMIN', "RAIN"]
# Explicit copy so the integer cast below does not write into a slice of df_1.
raw_data = df_1[col_list_raw].copy()

# Truncate latitude to whole degrees; it is the categorical axis of the
# violin plots.
raw_data['Lat'] = raw_data['Lat'].astype(int)
variables = ["TMAX", "TMIN", "RAIN", "TMEAN"]

# Plot style is identical for every variable, so it is configured once.
sns.set(style="ticks", palette='Set2')
sns.set_context("paper", font_scale=2, rc={"lines.linewidth": 1.1})
plt.rcParams['xtick.labelsize'] = 20
plt.rcParams['ytick.labelsize'] = 20

# Per variable: (x-axis caption, x position of that caption, x position of the
# vertical "Latitude" caption). The positions are hand-tuned coordinates
# passed to plt.text.
lat_plot_captions = {
    "TMAX": ('Maximum air temperature ($^oC$)', -40, -65.8),
    "TMIN": ('Minimum air temperature ($^oC$)', -55, -80.8),
    "RAIN": ('Precipitation (mm)', -900, -2000),
    "TMEAN": ('Mean air temperature ($^oC$)', -40, -70.8),
}

for var in variables:
    # One facet per period: horizontal violins of `var` for each latitude,
    # coloured by city.
    g = sns.catplot(x=var, y="Lat",
                    col="Period", data=raw_data, hue='City', kind="violin", orient="h",
                    height=4, aspect=24/18,
                    width=2, fliersize=2.5, showfliers=False, col_wrap=2, linewidth=0.1,
                    legend=False, notch=False)
    g.fig.axes[0].invert_yaxis()
    g.set_ylabels("")

    if var in lat_plot_captions:
        caption, caption_x, latitude_x = lat_plot_captions[var]
        # The per-facet labels are blanked and replaced by one shared caption
        # per axis. ("Maxmium" in the TMAX caption was a typo and is fixed.)
        g.set_xlabels("")
        plt.text(caption_x, -3.2, caption, va='center', size=30)
        plt.text(latitude_x, 12, 'Latitude', va='center', rotation='vertical', size=30)

    plt.legend(bbox_to_anchor=(-0.08, -0.7), frameon=False,
               loc="lower center", markerscale=1, ncol=3, borderpad=0.1,
               labelspacing=.7, handletextpad=0.5, fontsize=20)

    # JPEG encoder options go through pil_kwargs; the standalone `optimize`
    # and `quality` keywords are no longer accepted by matplotlib's savefig.
    g.savefig(r"D:\Davide\Figures\Lat_"+var+".jpg",
              pil_kwargs={"optimize": True, "quality": 95},
              transparent=True, dpi=300)

# =============================================================================
# # =============================================================================
# The box plots for Percentage change (cities in individual plots)
# =============================================================================
# =============================================================================

df = df[df['year'] != 1979]
df = df[df['year'] != 2100]

VARIABLES = ['RAIN', 'TMAX', 'TMIN', 'TMEAN']

# Per-variable plot settings:
#   label / label_common : y-axis captions (label_common is the one drawn)
#   heading              : title annotated above the top box-plot panel
#   vmin, vmax           : y-limits of the box plots
#   txt1, txt2           : hand-tuned position of the shared y-axis caption
#   q1, q2               : upper / lower quantiles used to trim outliers
#   changevar            : name of the change column that is plotted
#   cmp, heatmap_vmin/max: colour map and colour limits of the heat maps
TEMPERATURE_CAPTION = "Air temperature change from the baseline ($^oC$)"
CHANGE_SETTINGS = {
    "RAIN": dict(label="Precipitation change from the baseline (%)",
                 label_common="Precipitation change from the baseline (%)",
                 heading="", vmin=-110, vmax=110, txt1=-0.84, txt2=250,
                 q1=0.99, q2=0.07, changevar="RainChange", cmp='plasma_r',
                 heatmap_vmin=-55, heatmap_vmax=35),
    "TMAX": dict(label="Maximum air temperature change from the baseline ($^oC$)",
                 label_common=TEMPERATURE_CAPTION,
                 heading="Maximum air temperature", vmin=-5, vmax=10,
                 txt1=-0.89, txt2=19, q1=0.995, q2=0.0001,
                 changevar="TmaxChange", cmp="Reds",
                 heatmap_vmin=0, heatmap_vmax=7.5),
    "TMIN": dict(label="Minimum air temperature change from the baseline ($^oC$)",
                 label_common=TEMPERATURE_CAPTION,
                 heading="Minimum air temperature", vmin=-5, vmax=10,
                 txt1=-0.89, txt2=19, q1=0.99, q2=0.0001,
                 changevar="TminChange", cmp="Reds",
                 heatmap_vmin=0, heatmap_vmax=7.5),
    "TMEAN": dict(label="Mean air temperature change from the baseline ($^oC$)",
                  label_common=TEMPERATURE_CAPTION,
                  heading="Mean air temperature", vmin=-5, vmax=10,
                  txt1=-0.89, txt2=20, q1=0.99, q2=0.0001,
                  changevar="TmeanChange", cmp="Reds",
                  heatmap_vmin=0, heatmap_vmax=7.5),
}

# Baseline records before 2010; identical for every variable.
baseline = df.loc[df['RCP'] == "Baseline"]
baseline = baseline[baseline['year'] < 2010]
# Sorted so the city (hue / colour) order is the same on every run; the
# iteration order of a set of strings changes between interpreter sessions.
allcities = sorted(df["City"].dropna().unique())

# Figure styling shared by all box plots and heat maps. It is set before any
# axes are created so that every panel picks it up.
plt.rcParams['axes.titleweight'] = 'bold'
plt.rcParams['xtick.labelsize'] = 40
plt.rcParams['ytick.labelsize'] = 36
plt.rcParams['axes.labelsize'] = 30
plt.rcParams["axes.labelweight"] = "bold"
plt.rcParams["font.size"] = 36
plt.rcParams["font.weight"] = 800
plt.rc('grid', linestyle="--", color='grey')

for VARIABLE in VARIABLES:
    settings = CHANGE_SETTINGS[VARIABLE]
    label = settings["label"]
    label_common = settings["label_common"]
    heading = settings["heading"]
    vmin, vmax = settings["vmin"], settings["vmax"]
    txt1, txt2 = settings["txt1"], settings["txt2"]
    q1, q2 = settings["q1"], settings["q2"]
    changevar = settings["changevar"]
    cmp = settings["cmp"]
    heatmap_vmin, heatmap_vmax = settings["heatmap_vmin"], settings["heatmap_vmax"]

    # =============================================================================
    # Change from the baseline for every scenario record
    # =============================================================================

    # Baseline mean of VARIABLE per grid cell. Only that column is averaged:
    # taking the mean of the string columns raises TypeError on pandas >= 2.0.
    # The table covers all cities and is the same for each of them, so it is
    # built once per variable rather than once per city.
    Rain_base = baseline.groupby(['Lat', 'Lon'])[VARIABLE].mean().reset_index()
    Rain_base.columns = ['Lat', 'Lon', 'data_avg']
    # Coordinates are matched as two-decimal strings to avoid float mismatches.
    Rain_base["Lat"] = Rain_base["Lat"].map('{:.2f}'.format)
    Rain_base["Lon"] = Rain_base["Lon"].map('{:.2f}'.format)
    Rain_base["new"] = Rain_base["Lat"].astype(str) + "_" + Rain_base["Lon"].astype(str)

    PRCP2 = []
    for City in allcities:
        # Explicit copy: the coordinate columns are rewritten below.
        Pch = df_1.loc[df_1['City'] == City].copy()
        Pch["Lat"] = Pch["Lat"].map('{:.2f}'.format)
        Pch["Lon"] = Pch["Lon"].map('{:.2f}'.format)
        Pch["new"] = Pch["Lat"].astype(str) + "_" + Pch["Lon"].astype(str)
        withlatlon = Pch.merge(Rain_base, on=['Lat', 'Lon'])

        # data_avg is the baseline mean of the current VARIABLE, so only the
        # column named by `changevar` is meaningful in a given pass; the other
        # three are still computed to keep the set of columns unchanged.
        # NOTE(review): the rain change divides by the projected value, not
        # by the baseline mean. A conventional percent change from baseline
        # would divide by data_avg - confirm before changing, because the
        # y-limits and colour limits above were presumably tuned to this form.
        withlatlon['RainChange'] = 100 * (withlatlon["RAIN"] - withlatlon["data_avg"]) / withlatlon["RAIN"]
        withlatlon['TmaxChange'] = withlatlon["TMAX"] - withlatlon["data_avg"]
        withlatlon['TminChange'] = withlatlon["TMIN"] - withlatlon["data_avg"]
        withlatlon['TmeanChange'] = withlatlon["TMEAN"] - withlatlon["data_avg"]

        PRCP2.append(withlatlon)

    Pch_all = pd.concat(PRCP2, axis=0, ignore_index=True)
    Pch_all['RCP'] = Pch_all['RCP'].str.strip().replace(
        {"R26": "SSP1-2.6", "R70": "SSP3-7.0", "R85": "SSP5-8.5"}, regex=True)

    # =============================================================================
    # Box plots of the change, all GCMs pooled, one panel per time window
    # =============================================================================

    # Trim the tails: keep values strictly between the q2 and q1 quantiles.
    q_h = Pch_all[changevar].quantile(q1)
    q_l = Pch_all[changevar].quantile(q2)
    Pch_all_GCM = Pch_all[(Pch_all[changevar] < q_h) & (Pch_all[changevar] > q_l)]

    Pch_dataframes = [
        Pch_all_GCM.loc[Pch_all_GCM['Period'] == period]
        for period in ("2010-2039", "2040-2069", "2070-2099")
    ]

    fig1, axes = plt.subplots(3, 1, figsize=[23,26], sharex=True, sharey=True, dpi=100)
    cols = ["Time window: 2010-2039", "Time window: 2040-2069", "Time window: 2070-2099"]
    pad = 5  # in points

    axes[0].annotate(heading, xy=(0.5, 1), xytext=(0, pad),
                     xycoords='axes fraction', textcoords='offset points',
                     size=50, ha='center', va='baseline')
    fig1.subplots_adjust(hspace=.1)

    city_palette = sns.color_palette(['#009392', '#39b185', '#9ccb86', '#e9e29c',
                                      '#eeb479', '#e88471', '#cf597e'])
    for d, dataframe in enumerate(Pch_dataframes):
        sns.boxplot(x="RCP", y=changevar, hue='City',
                    data=dataframe, ax=axes[d],
                    width=0.6, fliersize=1.5, showfliers=False, linewidth=1.5,
                    notch=False, orient="v", palette=city_palette)
        axes[d].axhline(0, ls='--', color="red", linewidth=3)
        axes[d].set_xlabel("")
        axes[d].set_ylabel("")
        axes[d].set_ylim(vmin, vmax)

        # An empty panel has no legend to remove.
        panel_legend = axes[d].get_legend()
        if panel_legend is not None:
            panel_legend.remove()
        axes[d].annotate(cols[d], xy=(0.03, 0.90), xycoords='axes fraction', fontsize=35)

        if d == 2:
            # Only the bottom panel carries the (shared) legend.
            axes[d].legend(loc='lower center', bbox_to_anchor=(0.5, -.6),
                           ncol=3, fancybox=True, shadow=True, fontsize=42)
        axes[d].grid(axis='y')
    plt.text(txt1, txt2, label_common, va='center', rotation='vertical', size=50)

    plt.show()

    # =============================================================================
    # Heat maps of the yearly mean change: one figure per time window,
    # one row per scenario, cities on the y-axis and years on the x-axis
    # =============================================================================

    # Min / max of every heat map, collected as an aid for choosing
    # heatmap_vmin / heatmap_vmax by hand.
    a = []
    b = []
    for dataframe in Pch_dataframes:
        Pch_R_dataframes = [
            dataframe.loc[dataframe['RCP'] == scenario]
            for scenario in ("SSP1-2.6", "SSP3-7.0", "SSP5-8.5")
        ]
        fig, axes = plt.subplots(3, 1, figsize=[23,20], sharex=True, sharey=True, dpi=100)
        fig.subplots_adjust(hspace=0.05)
        for R, dataframe_R in enumerate(Pch_R_dataframes):
            # Average the plotted column only (string columns cannot be
            # averaged on pandas >= 2.0).
            RCP_data_grp = dataframe_R.groupby(['year', 'City'])[changevar].mean().reset_index()
            RCP_data_grp['year'] = RCP_data_grp['year'].astype(int)
            a.append(RCP_data_grp[changevar].min())
            b.append(RCP_data_grp[changevar].max())

            RCP_data_pvt = RCP_data_grp.pivot_table(changevar, ['City'], 'year')
            im = sns.heatmap(RCP_data_pvt, linewidth=1, linecolor='w', cmap=cmp, cbar=False,
                             vmin=heatmap_vmin, vmax=heatmap_vmax, square=True, ax=axes[R],
                             annot_kws={"fontsize": 18, "weight": "bold"})
            axes[R].set_ylabel("")
            axes[R].set_xlabel("")

        # One horizontal colour bar per figure, taken from the last heat map
        # (all three share the same colour limits).
        mappable = im.get_children()[0]
        plt.colorbar(mappable, ax=[axes[0], axes[1], axes[2]], pad=0.12, orientation='horizontal')


# =============================================================================
# =============================================================================
# # Generating the spatial plots 
# =============================================================================
# =============================================================================

import numpy as np
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import contextily as ctx
import geopandas as gpd
import matplotlib as mpl


col_list_YTR = ['TRT', 'year', 'GIS_code', 'Country', 'CNCT', 'GCM', 'RCP', 'Yield',
                'PRCP', 'TMAX', 'TMIN']
# Raw string: '\G' is not a valid escape sequence in a normal string literal.
gis_data = pd.read_csv(r'\GIS-Jan2021.csv')
YTR = DSSAT[col_list_YTR].merge(gis_data, how='left', on='TRT')
YTR['Tmean'] = (YTR['TMAX'] + YTR['TMIN']) / 2
# Divide yield by 1000 (plotted below as "t DM/ ha").
YTR['Yield'] = YTR['Yield'] / 1000
YTR['City'] = YTR['CNCT'].str[2:].str.strip().replace(
    {
        "Ca": "California",
        "Em": "Emilia",
        "Fo": "Foggia",
        "In": "Inner Mongolia",
        "Ga": "Gansu",
        "Xi": "Xinjiang",
    },
    regex=True,
)

# Assign each row to a 30-year window; rows from 2070 onwards (and rows whose
# year is missing) keep the default label.
YTR['Period'] = "2070-2099"
ytr_year = YTR["year"]
# .loc assignments replace the former chained indexing
# (YTR['Period'][mask] = ...), which has no effect under Copy-on-Write.
YTR.loc[ytr_year < 2010, 'Period'] = "1980-2009"
YTR.loc[(ytr_year >= 2010) & (ytr_year < 2040), 'Period'] = "2010-2039"
YTR.loc[(ytr_year >= 2040) & (ytr_year < 2070), 'Period'] = "2040-2069"

RCPs = ["RCP 2.6", "RCP 7.0", "RCP 8.5"]
Periods = ["1980-2009", "2010-2039", "2040-2069", "2070-2099"]


def _plot_yield_maps(data, city, shapefile, rcps, periods, *, figsize, title_size,
                     tick_size, marker_size, boundary_width, cbar_tick_size,
                     cbar_tick_pad, cbar_labelpad=None, vmin=None, vmax=None,
                     xlim=None, ylim=None):
    """Draw a grid of yield maps for one city: rows are scenarios, columns periods.

    Each panel shows the mean ``Yield`` per (LON, LAT) point of ``data`` as
    coloured squares, the outline read from ``shapefile`` in yellow, and an
    Esri WorldImagery basemap. One colour bar is added for the whole figure.

    Parameters
    ----------
    data : DataFrame with the columns City, RCP, Period, LON, LAT and Yield;
        RCP holds the short codes "R26" / "R70" / "R85".
    city : value of the City column to plot.
    shapefile : path handed to ``geopandas.read_file`` for the outline.
    rcps, periods : row and column labels; they are also the values matched
        against the (relabelled) RCP column and the Period column.
    figsize : figure size passed to ``plt.subplots``.
    title_size : font size of the row/column headings and the colour-bar label.
    tick_size : font size of the tick labels and the axis labels.
    marker_size : marker area (``s``) of the yield squares.
    boundary_width : line width of the outline.
    cbar_tick_size, cbar_tick_pad : colour-bar tick label size and padding.
    cbar_labelpad : optional padding of the colour-bar label.
    vmin, vmax : optional colour limits shared by every panel.
    xlim, ylim : optional (min, max) axis limits.

    Returns
    -------
    (fig, (lowest, highest)) : the figure and the range of the panel means,
        useful for choosing ``vmin`` / ``vmax``.
    """
    boundary = gpd.read_file(shapefile)
    # Explicit copy so relabelling the scenarios does not write into `data`.
    city_data = data.loc[data['City'] == city].copy()
    city_data['RCP'] = city_data['RCP'].str.strip().replace(
        {"R26": "RCP 2.6", "R70": "RCP 7.0", "R85": "RCP 8.5"}, regex=True)

    # Tick sizes are set before the axes are created so they apply to them.
    plt.rcParams['xtick.labelsize'] = tick_size
    plt.rcParams['ytick.labelsize'] = tick_size
    fig, axes = plt.subplots(len(rcps), len(periods), figsize=figsize,
                             sharex=True, sharey=True, dpi=100, squeeze=False)

    pad = 20  # in points
    # Column headings above the top row, row headings left of the first column.
    for ax, period in zip(axes[0], periods):
        ax.annotate(period, xy=(0.5, 1), xytext=(0, pad),
                    xycoords='axes fraction', textcoords='offset points',
                    size=title_size, ha='center', va='baseline')
    for ax, rcp in zip(axes[:, 0], rcps):
        ax.annotate(rcp, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                    xycoords=ax.yaxis.label, textcoords='offset points',
                    size=title_size, ha='right', va='center', rotation=90)
    fig.subplots_adjust(hspace=0.05, wspace=0.05)

    panel_means = []
    last_row = len(rcps) - 1
    for n, rcp in enumerate(rcps):
        rcp_data = city_data.loc[city_data['RCP'] == rcp]
        for d, period in enumerate(periods):
            ax = axes[n, d]
            panel = rcp_data.loc[rcp_data['Period'] == period]
            # Mean of the Yield column only: aggregating every column fails
            # on the string columns with pandas >= 2.0.
            grid_mean = (panel.groupby(['LON', 'LAT'])['Yield'].mean()
                         .reset_index(name='Yieldmean'))
            panel_means.append(grid_mean['Yieldmean'])
            points = gpd.GeoDataFrame(
                grid_mean,
                geometry=gpd.points_from_xy(grid_mean['LON'], grid_mean['LAT']),
                crs='EPSG:4326')

            points.plot(column='Yieldmean', cmap='Reds', vmin=vmin, vmax=vmax,
                        marker='s', linewidth=0.9, s=marker_size, ax=ax,
                        edgecolor='black', alpha=0.8)
            boundary.plot(edgecolor="yellow", color='None', linewidth=boundary_width, ax=ax)
            # Fix the view before the basemap is added so the basemap is
            # requested for the extent that is finally shown.
            if xlim is not None:
                ax.set_xlim(*xlim)
            if ylim is not None:
                ax.set_ylim(*ylim)
            ctx.add_basemap(ax, zoom='auto', attribution=False, crs='EPSG:4326',
                            source=ctx.providers.Esri.WorldImagery)

            # Axis labels only on the bottom row and the first column.
            ax.set_xlabel('')
            ax.set_ylabel('')
            if n == last_row:
                ax.set_xlabel('longitude', fontsize=tick_size)
            if d == 0:
                ax.set_ylabel('Latitude', fontsize=tick_size)

    # Shared colour bar, driven by the first artist of the top-left panel.
    cax, kw = mpl.colorbar.make_axes([ax for ax in axes.flat],
                                     label='Average simulated yield (t DM/ ha)')
    cbar = plt.colorbar(axes[0][0].get_children()[0], pad=0.01, cax=cax, **kw)
    cbar.set_label(label='Average simulated yield (t DM/ ha)', size=title_size,
                   labelpad=cbar_labelpad, weight='bold')
    cbar.ax.tick_params(labelsize=cbar_tick_size, pad=cbar_tick_pad)

    pooled = pd.concat(panel_means, ignore_index=True)
    return fig, (pooled.min(), pooled.max())


# =============================================================================
# California
# =============================================================================
# The yield table (YTR) is plotted here, like the other regions and in line
# with the colour-bar label. Previously this section switched to Pch_all,
# whose scenario labels ("SSP1-2.6", ...) never match RCPs, so every panel
# came out empty.
# NOTE(review): no vmin/vmax is given, so each panel is normalised on its own
# and the colour bar reflects the top-left panel only; pick limits from the
# returned range once they are known.
fig1, california_yield_range = _plot_yield_maps(
    YTR, "California", r"\clf.shp", RCPs, Periods,
    figsize=[37,30], title_size=55, tick_size=50, marker_size=300,
    boundary_width=2, cbar_tick_size=45, cbar_tick_pad=20,
    xlim=(-123, -118.2), ylim=(34.5, 41))

# =============================================================================
# XinJiang
# =============================================================================
fig1, xinjiang_yield_range = _plot_yield_maps(
    YTR, "Xinjiang", r"Xinjiang.shp", RCPs, Periods,
    figsize=[40,17], title_size=40, tick_size=35, marker_size=600,
    boundary_width=3, cbar_tick_size=35, cbar_tick_pad=10, cbar_labelpad=20,
    vmin=0.5, vmax=7, xlim=(79, 90), ylim=(39.5, 46.5))

# =============================================================================
# Inner Mongolia
# =============================================================================
fig1, inner_mongolia_yield_range = _plot_yield_maps(
    YTR, "Inner Mongolia", r"\inner.shp", RCPs, Periods,
    figsize=[40,20], title_size=40, tick_size=35, marker_size=500,
    boundary_width=3, cbar_tick_size=35, cbar_tick_pad=10,
    vmin=2.5, vmax=6)
# =============================================================================
# =============================================================================
# # # Gansu
# =============================================================================
# =============================================================================

# Gansu: region outline, the yield records of this city, and an empty 3 x 4
# grid of panels whose first row/column get Period/RCP header annotations.
Shp = gpd.read_file(r"\Gansu.shp")

# .copy() makes Spatial_city an independent frame, so the RCP relabelling below
# is a plain column assignment instead of a write to a filtered slice of YTR
# (which raises pandas' SettingWithCopyWarning).
Spatial_city = YTR.loc[YTR['City'] == "Gansu"].copy()
# Turn the short scenario codes into the display labels used for filtering
# and for the row headers.
Spatial_city['RCP'] = Spatial_city['RCP'].str.strip().replace(
    {"R26": "RCP 2.6", "R70": "RCP 7.0", "R85": "RCP 8.5"}, regex=True)

fig1, axes = plt.subplots(3, 4, figsize=[30, 10], sharex=True, sharey=True, dpi=100)
cols = Periods
rows = RCPs
pad = 20  # in points

# Column headers: one label centred above each panel of the top row.
for ax, col in zip(axes[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                xycoords='axes fraction', textcoords='offset points',
                size=30, ha='center', va='baseline')

# Row headers: one rotated label to the left of each panel of the first
# column, positioned relative to that panel's y-axis label.
for ax, row in zip(axes[:, 0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size=30, ha='right', va='center', rotation=90)


fig1.subplots_adjust(hspace=0.05, wspace=0.05)
# Per-panel minimum / maximum of the mean yield.  They are not used by the
# plots themselves; inspect them after a run to choose the fixed vmin/vmax
# passed to plot() below.
a = []
b = []

# NOTE(review): rcParams are global and only apply to ticks created after this
# point, while the panel grid already exists - confirm these sizes reach this
# figure (axes.tick_params(labelsize=...) would be explicit).
plt.rcParams['xtick.labelsize'] = 26
plt.rcParams['ytick.labelsize'] = 26

last_row = axes.shape[0] - 1
# One panel per (RCP, Period): rows follow RCPs, columns follow Periods.
for n, rcps in enumerate(RCPs):
    Spatial_city_rcp = Spatial_city.loc[Spatial_city['RCP'] == rcps]
    for d, prd in enumerate(Periods):
        panel = axes[n, d]
        Spatial_city_rcp_P_1 = Spatial_city_rcp.loc[Spatial_city_rcp['Period'] == prd]

        # Mean and std per grid point.  Only numeric columns are aggregated:
        # the frame also carries text columns (e.g. City, RCP) and pandas >= 2.0
        # raises a TypeError when asked for their mean instead of dropping them.
        PG_1 = (Spatial_city_rcp_P_1.select_dtypes(include='number')
                .groupby(['LON', 'LAT'])
                .agg(['mean', 'std'])
                .reset_index(level=[0, 1]))
        # Flatten the two-level column names, e.g. ('Yield', 'mean') -> 'Yieldmean';
        # LON and LAT keep their names because their second level is empty.
        PG_1.columns = PG_1.columns.get_level_values(0) + PG_1.columns.get_level_values(1)
        PG_1_gpd = gpd.GeoDataFrame(PG_1, geometry=gpd.points_from_xy(PG_1['LON'], PG_1['LAT']),
                                    crs='EPSG:4326')
        a.append(PG_1_gpd["Yieldmean"].min())
        b.append(PG_1_gpd["Yieldmean"].max())

        # Yield points first (already in EPSG:4326, so no reprojection needed),
        # then the region outline, then the satellite basemap.
        PG_1_gpd.plot(column='Yieldmean',
                      cmap='Reds', vmin=2.5, vmax=6,
                      linewidth=0.9, s=500, marker='s',
                      ax=panel,
                      edgecolor='black',
                      alpha=0.8)
        Shp.plot(edgecolor="yellow", color='None', linewidth=3, ax=panel)
        ctx.add_basemap(panel, zoom='auto', attribution=False, crs='EPSG:4326',
                        source=ctx.providers.Esri.WorldImagery)

        # Clear the axis labels, then label only the bottom row and the
        # first column.
        panel.set_xlabel('')
        panel.set_ylabel('')
        if n == last_row:
            panel.set_xlabel('longitude', fontsize=26)
        if d == 0:
            panel.set_ylabel('Latitude', fontsize=26)

        # Fixed map window (degrees of longitude / latitude).
        panel.set_xlim(99.5, 103.5)
        panel.set_ylim(37.5, 39.5)

# Shared colorbar for the Gansu grid: make_axes() carves one colorbar axes out
# of all subplots and returns the remaining colorbar keyword args.
cax, kw = mpl.colorbar.make_axes(list(axes.flat), label='Average simulated yield (t DM/ ha)')
# Use the first *collection* of the top-left panel (the yield points, plotted
# before the shapefile outline) as the mappable.  The header annotations are
# added to that panel before anything is plotted, and since matplotlib 3.5
# get_children() returns artists in insertion order, so get_children()[0]
# would be an annotation rather than something a colorbar can be built from.
yield_points = axes[0, 0].collections[0]
cbar = plt.colorbar(yield_points, cax=cax, **kw)
cbar.set_label(label='Average simulated yield (t DM/ ha)', size=30, labelpad=20, weight='bold')
cbar.ax.tick_params(labelsize=24)
# =============================================================================
# =============================================================================
# # # Emilia
# =============================================================================
# =============================================================================
# Emilia: region outline, the yield records of this city, and an empty 3 x 4
# grid of panels whose first row/column get Period/RCP header annotations.
Shp = gpd.read_file(r"\Emilia.shp")

# .copy() makes Spatial_city an independent frame, so the RCP relabelling below
# is a plain column assignment instead of a write to a filtered slice of YTR
# (which raises pandas' SettingWithCopyWarning).
Spatial_city = YTR.loc[YTR['City'] == "Emilia"].copy()
# Turn the short scenario codes into the display labels used for filtering
# and for the row headers.
Spatial_city['RCP'] = Spatial_city['RCP'].str.strip().replace(
    {"R26": "RCP 2.6", "R70": "RCP 7.0", "R85": "RCP 8.5"}, regex=True)

fig1, axes = plt.subplots(3, 4, figsize=[42, 15], sharex=True, sharey=True, dpi=120)
cols = Periods
rows = RCPs
pad = 20  # in points

# Column headers: one label centred above each panel of the top row.
for ax, col in zip(axes[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                xycoords='axes fraction', textcoords='offset points',
                size=45, ha='center', va='baseline')

# Row headers: one rotated label to the left of each panel of the first
# column, positioned relative to that panel's y-axis label.
for ax, row in zip(axes[:, 0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size=48, ha='right', va='center', rotation=90)


fig1.subplots_adjust(hspace=0.1, wspace=0.1)
# plt.setp(ax, xlim=(10, 35))
# plt.setp(ax, ylim=(0, 11))
# Per-panel minimum / maximum of the mean yield.  They are not used by the
# plots themselves; inspect them after a run to choose the fixed vmin/vmax
# passed to plot() below.
a = []
b = []

# NOTE(review): rcParams are global and only apply to ticks created after this
# point, while the panel grid already exists - confirm these sizes reach this
# figure (axes.tick_params(labelsize=...) would be explicit).
plt.rcParams['xtick.labelsize'] = 10
plt.rcParams['ytick.labelsize'] = 10

last_row = axes.shape[0] - 1
# One panel per (RCP, Period): rows follow RCPs, columns follow Periods.
for n, rcps in enumerate(RCPs):
    Spatial_city_rcp = Spatial_city.loc[Spatial_city['RCP'] == rcps]
    for d, prd in enumerate(Periods):
        panel = axes[n, d]
        Spatial_city_rcp_P_1 = Spatial_city_rcp.loc[Spatial_city_rcp['Period'] == prd]

        # Mean and std per grid point.  Only numeric columns are aggregated:
        # the frame also carries text columns (e.g. City, RCP) and pandas >= 2.0
        # raises a TypeError when asked for their mean instead of dropping them.
        PG_1 = (Spatial_city_rcp_P_1.select_dtypes(include='number')
                .groupby(['LON', 'LAT'])
                .agg(['mean', 'std'])
                .reset_index(level=[0, 1]))
        # Flatten the two-level column names, e.g. ('Yield', 'mean') -> 'Yieldmean';
        # LON and LAT keep their names because their second level is empty.
        PG_1.columns = PG_1.columns.get_level_values(0) + PG_1.columns.get_level_values(1)
        PG_1_gpd = gpd.GeoDataFrame(PG_1, geometry=gpd.points_from_xy(PG_1['LON'], PG_1['LAT']),
                                    crs='EPSG:4326')
        a.append(PG_1_gpd["Yieldmean"].min())
        b.append(PG_1_gpd["Yieldmean"].max())

        # Layering is set explicitly through zorder: basemap (0) at the bottom,
        # region outline (5) above it, yield points (10) on top.  The points
        # are already in EPSG:4326, so no reprojection is needed.
        PG_1_gpd.plot(column='Yieldmean',
                      cmap='Reds', vmin=2.5, vmax=8,
                      linewidth=0.9, s=600, marker='s',
                      ax=panel,
                      edgecolor='black',
                      alpha=0.8, zorder=10)
        Shp.plot(edgecolor="yellow", color='None', linewidth=2, ax=panel, zorder=5)
        ctx.add_basemap(panel, attribution=False, zoom=7, crs='EPSG:4326',
                        source=ctx.providers.Esri.WorldImagery, zorder=0)

        # Clear the axis labels, then label only the bottom row and the
        # first column.
        panel.set_xlabel('')
        panel.set_ylabel('')
        if n == last_row:
            panel.set_xlabel('longitude', fontsize=30)
        if d == 0:
            panel.set_ylabel('Latitude', fontsize=30)

        # Fixed map window (degrees of longitude / latitude).
        panel.set_xlim(9, 13)
        panel.set_ylim(44, 46)

# Shared colorbar for the Emilia grid: make_axes() carves one colorbar axes out
# of all subplots and returns the remaining colorbar keyword args.
cax, kw = mpl.colorbar.make_axes(list(axes.flat), label='Average simulated yield (t DM/ ha)')
# Use the first *collection* of the top-left panel (the yield points, plotted
# before the shapefile outline) as the mappable.  The header annotations are
# added to that panel before anything is plotted, and since matplotlib 3.5
# get_children() returns artists in insertion order, so get_children()[0]
# would be an annotation rather than something a colorbar can be built from.
yield_points = axes[0, 0].collections[0]
cbar = plt.colorbar(yield_points, cax=cax, **kw)
cbar.set_label(label='Average simulated yield (t DM/ ha)', size=45, labelpad=20, weight='bold')
cbar.ax.tick_params(labelsize=40)
# =============================================================================
# =============================================================================
# # # Foggia
# =============================================================================
# =============================================================================
# Foggia: region outline, the yield records of this city, and an empty 3 x 4
# grid of panels whose first row/column get Period/RCP header annotations.
# NOTE(review): the shapefile is spelled "Fogia" while the city is "Foggia" -
# confirm the file name on disk.
Shp = gpd.read_file(r"\Fogia.shp")
# .copy() makes Spatial_city an independent frame, so the RCP relabelling below
# is a plain column assignment instead of a write to a filtered slice of YTR
# (which raises pandas' SettingWithCopyWarning).
Spatial_city = YTR.loc[YTR['City'] == "Foggia"].copy()
# Turn the short scenario codes into the display labels used for filtering
# and for the row headers.
Spatial_city['RCP'] = Spatial_city['RCP'].str.strip().replace(
    {"R26": "RCP 2.6", "R70": "RCP 7.0", "R85": "RCP 8.5"}, regex=True)


fig1, axes = plt.subplots(3, 4, figsize=[25, 14], sharex=True, sharey=True, dpi=120)
cols = Periods
rows = RCPs
pad = 20  # in points

# Column headers: one label centred above each panel of the top row.
for ax, col in zip(axes[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad),
                xycoords='axes fraction', textcoords='offset points',
                size=30, ha='center', va='baseline')

# Row headers: one rotated label to the left of each panel of the first
# column, positioned relative to that panel's y-axis label.
for ax, row in zip(axes[:, 0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad - pad, 0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size=30, ha='right', va='center', rotation=90)


fig1.subplots_adjust(hspace=0.01, wspace=0.05)
# plt.setp(ax, xlim=(10, 35))
# plt.setp(ax, ylim=(0, 11))
# Per-panel minimum / maximum of the mean yield.  They are not used by the
# plots themselves; inspect them after a run to choose the fixed vmin/vmax
# passed to plot() below.
a = []
b = []

# NOTE(review): rcParams are global and only apply to ticks created after this
# point, while the panel grid already exists - confirm these sizes reach this
# figure (axes.tick_params(labelsize=...) would be explicit).
plt.rcParams['xtick.labelsize'] = 24
plt.rcParams['ytick.labelsize'] = 24

last_row = axes.shape[0] - 1
# One panel per (RCP, Period): rows follow RCPs, columns follow Periods.
for n, rcps in enumerate(RCPs):
    Spatial_city_rcp = Spatial_city.loc[Spatial_city['RCP'] == rcps]
    for d, prd in enumerate(Periods):
        panel = axes[n, d]
        Spatial_city_rcp_P_1 = Spatial_city_rcp.loc[Spatial_city_rcp['Period'] == prd]

        # Mean and std per grid point.  Only numeric columns are aggregated:
        # the frame also carries text columns (e.g. City, RCP) and pandas >= 2.0
        # raises a TypeError when asked for their mean instead of dropping them.
        PG_1 = (Spatial_city_rcp_P_1.select_dtypes(include='number')
                .groupby(['LON', 'LAT'])
                .agg(['mean', 'std'])
                .reset_index(level=[0, 1]))
        # Flatten the two-level column names, e.g. ('Yield', 'mean') -> 'Yieldmean';
        # LON and LAT keep their names because their second level is empty.
        PG_1.columns = PG_1.columns.get_level_values(0) + PG_1.columns.get_level_values(1)
        PG_1_gpd = gpd.GeoDataFrame(PG_1, geometry=gpd.points_from_xy(PG_1['LON'], PG_1['LAT']),
                                    crs='EPSG:4326')
        a.append(PG_1_gpd["Yieldmean"].min())
        b.append(PG_1_gpd["Yieldmean"].max())

        # Yield points (zorder 3) sit above the region outline (zorder 1); the
        # satellite basemap is added last.  The points are already in
        # EPSG:4326, so no reprojection is needed.
        PG_1_gpd.plot(column='Yieldmean',
                      cmap='Reds', vmin=1.5, vmax=7, marker='s',
                      linewidth=0.9, s=250,
                      ax=panel,
                      edgecolor='black',
                      alpha=0.8, zorder=3)
        Shp.plot(edgecolor="yellow", color='None', linewidth=2, ax=panel, zorder=1)
        ctx.add_basemap(panel, zoom=7, attribution=False, crs='EPSG:4326',
                        source=ctx.providers.Esri.WorldImagery)

        # Clear the axis labels, then label only the bottom row and the
        # first column.
        panel.set_xlabel('')
        panel.set_ylabel('')
        if n == last_row:
            panel.set_xlabel('longitude', fontsize=26)
        if d == 0:
            panel.set_ylabel('Latitude', fontsize=26)

        # Fixed map window (degrees of longitude / latitude).
        panel.set_xlim(14, 17)
        panel.set_ylim(40.5, 43)

# Shared colorbar for the Foggia grid: make_axes() carves one colorbar axes out
# of all subplots and returns the remaining colorbar keyword args.
cax, kw = mpl.colorbar.make_axes(list(axes.flat), label='Average simulated yield (t DM/ ha)')
# Use the first *collection* of the top-left panel (the yield points, plotted
# before the shapefile outline) as the mappable.  The header annotations are
# added to that panel before anything is plotted, and since matplotlib 3.5
# get_children() returns artists in insertion order, so get_children()[0]
# would be an annotation rather than something a colorbar can be built from.
yield_points = axes[0, 0].collections[0]
cbar = plt.colorbar(yield_points, cax=cax, **kw)
cbar.set_label(label='Average simulated yield (t DM/ ha)', size=30, labelpad=20, weight='bold')
cbar.ax.tick_params(labelsize=25)
